import os
from config import *
from utils import process, t_plot, write_log
import torch

# Collects every log line written via write_log; flushed to disk at the end.
txt_list = []
# NOTE(review): selects column `output_n + 1` of `data` as the series handed to
# t_plot — confirm this offset is the intended target column.
all_data = data[:, output_n + 1]

# Sequential split of `data` into train | val | test (no shuffling).
test_num = int(len(data) * (1 - test_rate))  # first row index of the test split
val_start = int(len(data) * (1 - val_rate - test_rate))
train_data = data[:val_start]
val_data = data[val_start:test_num]
test_data = data[test_num:]

# Build the test inputs only after test_data has been sliced above; previously
# this line ran before test_data was assigned.
Dte = process(test_data, 1, False, interval, pred_size, output_n)

# map_location loads the weights onto the configured device even if the model
# was saved from a different one (e.g. trained on GPU, evaluated on CPU).
# NOTE(review): this loads a pickled full model; torch >= 2.6 defaults to
# weights_only=True, which may reject it — pass weights_only=False only for
# trusted checkpoint files.
model = torch.load(f'{model_path}/model-{r_name}.pth', map_location=device)
write_log(f'model: {r_name}', txt_list)
y_re, y_real = t_plot(result_path, r_name, model, test_num, t_date, Dte, all_data, device, pred_size,
                      interval, txt_list)

# Regression metrics comparing predictions (y_re) with ground truth (y_real).
# NOTE(review): assumes y_re and y_real are numpy arrays of identical shape —
# confirm against t_plot.
err = y_re - y_real
abs_err = np.abs(err)

# Percentage error is undefined where the target is 0 (division by zero gives
# inf/nan), so those points are excluded from MAPE and their count is logged.
nonzero = y_real != 0
n_zero = int(np.size(nonzero) - np.count_nonzero(nonzero))

# np.mean averages over every element, so the metrics stay correct even if the
# arrays are 2-D (sum / len would scale by the number of columns).
MSE = np.mean(err ** 2)
RMSE = np.sqrt(MSE)
MAE = np.mean(abs_err)
if np.any(nonzero):
    MAPE = np.mean(abs_err[nonzero] / np.abs(y_real[nonzero])) * 100
else:
    MAPE = float('nan')  # no non-zero target: MAPE is undefined

# SMAPE denominator is the mean magnitude of prediction and target. When both
# are 0 the prediction is exact, so that point contributes 0 instead of nan.
denom = (np.abs(y_re) + np.abs(y_real)) / 2
S = np.divide(abs_err, denom, out=np.zeros(np.shape(abs_err)), where=denom != 0)
SMAPE = np.mean(S) * 100

# Mean along the first axis — identical to the previous sum(y_real) / len(y_real).
av_y = np.mean(y_real, axis=0)
R2 = 1 - np.sum(err ** 2) / np.sum((av_y - y_real) ** 2)

write_log(f'MSE: {MSE}', txt_list)
write_log(f'RMSE: {RMSE}', txt_list)
write_log(f'MAE: {MAE}', txt_list)
write_log(f'R2: {R2}', txt_list)
write_log(f'MAPE: {MAPE} %', txt_list)
if n_zero:
    write_log(f'MAPE excludes {n_zero} zero-valued target(s)', txt_list)
write_log(f'SMAPE: {SMAPE} %', txt_list)

# Persist every entry accumulated in txt_list (via write_log, and the list also
# passed to t_plot) to <result_path>/log/log-<r_name>.txt, overwriting any
# previous file for this run name.
os.makedirs(f'{result_path}/log', exist_ok=True)
# str.join builds the text in one pass instead of repeated += concatenation.
content = ''.join(txt_list)
# 'w' is sufficient: the file is only written, never read back here.
with open(f'{result_path}/log/log-{r_name}.txt', 'w', encoding='utf8') as f:
    f.write(content)
